{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.6/dist-packages/sklearn/cross_validation.py:41: DeprecationWarning: This module was deprecated in version 0.18 in favor of the model_selection module into which all the refactored classes and functions are moved. Also note that the interface of the new CV iterators are different from that of this module. This module will be removed in 0.20.\n",
      "  \"This module will be removed in 0.20.\", DeprecationWarning)\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "from sklearn.utils import shuffle\n",
    "import re\n",
    "import time\n",
    "import collections\n",
    "import os\n",
    "import itertools\n",
    "from sklearn.cross_validation import train_test_split"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def build_dataset(words, n_words, atleast=1):\n",
    "    \"\"\"Map a flat token sequence onto integer ids.\n",
    "\n",
    "    Ids 0-3 are reserved for the special tokens GO, PAD, EOS and UNK (in that\n",
    "    order); every kept token then receives the next free id, most frequent first.\n",
    "\n",
    "    Args:\n",
    "        words: sequence of hashable tokens (the whole corpus, flattened).\n",
    "        n_words: keep at most this many distinct tokens.\n",
    "        atleast: drop tokens that occur fewer than this many times.\n",
    "\n",
    "    Returns:\n",
    "        data: list holding the id of every token in `words`; tokens that were\n",
    "            filtered out are mapped to the UNK id.\n",
    "        count: list of [token, n] entries, special tokens first. The UNK entry\n",
    "            carries the number of unknown tokens met in `words`; the other\n",
    "            special entries keep their placeholder value.\n",
    "        dictionary: token -> id.\n",
    "        reversed_dictionary: id -> token.\n",
    "    \"\"\"\n",
    "    count = [['GO', 0], ['PAD', 1], ['EOS', 2], ['UNK', 3]]\n",
    "    frequent = collections.Counter(words).most_common(n_words)\n",
    "    count.extend(pair for pair in frequent if pair[1] >= atleast)\n",
    "    dictionary = dict()\n",
    "    for word, _ in count:\n",
    "        dictionary[word] = len(dictionary)\n",
    "    # Out-of-vocabulary tokens must become UNK; id 0 is GO, not UNK.\n",
    "    unk_index = dictionary['UNK']\n",
    "    data = list()\n",
    "    unk_count = 0\n",
    "    for word in words:\n",
    "        if word in dictionary:\n",
    "            data.append(dictionary[word])\n",
    "        else:\n",
    "            data.append(unk_index)\n",
    "            unk_count += 1\n",
    "    # Store the tally on the UNK row instead of overwriting the GO row.\n",
    "    count[unk_index][1] = unk_count\n",
    "    reversed_dictionary = dict(zip(dictionary.values(), dictionary.keys()))\n",
    "    return data, count, dictionary, reversed_dictionary\n",
    "\n",
    "def add_start_end(string):\n",
    "    \"\"\"Split `string` on whitespace and explode every word into characters.\n",
    "\n",
    "    The first character of each word is prefixed with '<' and the last one is\n",
    "    suffixed with '>', so word boundaries survive in the flat character list\n",
    "    (a one-letter word 'a' becomes the single token '<a>').\n",
    "    \"\"\"\n",
    "    marked = []\n",
    "    for word in string.split():\n",
    "        chars = list(word)\n",
    "        chars[0] = '<' + chars[0]\n",
    "        chars[-1] = chars[-1] + '>'\n",
    "        marked += chars\n",
    "    return marked"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "10000 10000\n"
     ]
    }
   ],
   "source": [
    "# Every usable line is tab-separated: column 0 is the target ('after') form,\n",
    "# column 1 the input ('before') form. Only the first 10000 lines are read.\n",
    "with open('lemmatization-en.txt','r') as handle:\n",
    "    texts = handle.read().split('\\n')\n",
    "\n",
    "after, before = [], []\n",
    "for raw_line in texts[:10000]:\n",
    "    # drop non-ASCII characters, then lower-case before splitting the columns\n",
    "    ascii_line = raw_line.encode('ascii', 'ignore').decode(\"utf-8\").lower()\n",
    "    fields = ascii_line.split('\\t')\n",
    "    if len(fields) >= 2:\n",
    "        after.append(add_start_end(fields[0]))\n",
    "        before.append(add_start_end(fields[1]))\n",
    "\n",
    "print(len(after),len(before))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "vocab from size: 65\n",
      "Most common words [('e', 9763), ('i', 7229), ('n', 6177), ('s>', 6095), ('o', 5769), ('a', 5633)]\n",
      "Sample data [46, 5, 11, 14, 35, 47, 4, 6, 10, 39] ['<f', 'i', 'r', 's', 't>', '<t', 'e', 'n', 't', 'h>']\n",
      "filtered vocab size: 69\n",
      "% of vocab used: 106.15%\n"
     ]
    }
   ],
   "source": [
    "# Character vocabulary of the input ('before') side.\n",
    "concat_from = list(itertools.chain.from_iterable(before))\n",
    "vocabulary_size_from = len(set(concat_from))\n",
    "data_from, count_from, dictionary_from, rev_dictionary_from = build_dataset(\n",
    "    concat_from, vocabulary_size_from)\n",
    "print('vocab from size: %d'%(vocabulary_size_from))\n",
    "print('Most common words', count_from[4:10])\n",
    "sample_ids_from = data_from[:10]\n",
    "print('Sample data', sample_ids_from, [rev_dictionary_from[idx] for idx in sample_ids_from])\n",
    "print('filtered vocab size:',len(dictionary_from))\n",
    "# the dictionary also holds the 4 reserved tokens, hence a ratio above 100%\n",
    "print(\"% of vocab used: {}%\".format(round(len(dictionary_from)/vocabulary_size_from,4)*100))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "vocab from size: 90\n",
      "Most common words [('o', 5631), ('a', 5471), ('e', 5471), ('i', 4890), ('r', 4252), ('<c', 4134)]\n",
      "Sample data [83, 59, 57, 59, 55, 57, 59, 55, 55, 57] ['<1>', '<1', '0>', '<1', '0', '0>', '<1', '0', '0', '0>']\n",
      "filtered vocab size: 94\n",
      "% of vocab used: 104.44%\n"
     ]
    }
   ],
   "source": [
    "# Character vocabulary of the target ('after') side.\n",
    "concat_to = list(itertools.chain(*after))\n",
    "vocabulary_size_to = len(list(set(concat_to)))\n",
    "data_to, count_to, dictionary_to, rev_dictionary_to = build_dataset(concat_to, vocabulary_size_to)\n",
    "# This cell reports the target vocabulary, so label it 'to' (the old label\n",
    "# was a copy-paste of the source-side cell and said 'from').\n",
    "print('vocab to size: %d'%(vocabulary_size_to))\n",
    "print('Most common words', count_to[4:10])\n",
    "print('Sample data', data_to[:10], [rev_dictionary_to[i] for i in data_to[:10]])\n",
    "print('filtered vocab size:',len(dictionary_to))\n",
    "# the dictionary also holds the 4 reserved tokens, hence a ratio above 100%\n",
    "print(\"% of vocab used: {}%\".format(round(len(dictionary_to)/vocabulary_size_to,4)*100))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Integer ids of the four reserved tokens, looked up in the source-side dictionary.\n",
    "GO, PAD, EOS, UNK = (dictionary_from[token] for token in ('GO', 'PAD', 'EOS', 'UNK'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Terminate every target sequence with the EOS token. The guard keeps the cell\n",
    "# idempotent: re-running it must not stack a second EOS onto each sequence.\n",
    "for target_tokens in after:\n",
    "    if not target_tokens or target_tokens[-1] != 'EOS':\n",
    "        target_tokens.append('EOS')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Stemmer:\n",
    "    \"\"\"Character-level seq2seq graph built on TF1 / tf.contrib.seq2seq.\n",
    "\n",
    "    A stack of bidirectional LSTM layers encodes the input ids. A multi-layer\n",
    "    LSTM decoder sharing one set of weights is then unrolled twice: once\n",
    "    teacher-forced (loss, accuracy) and once greedily (`predicting_ids`).\n",
    "\n",
    "    The ids of the special tokens are read from the module-level names GO, PAD\n",
    "    and EOS, so those must exist before the class is instantiated.\n",
    "\n",
    "    Args:\n",
    "        size_layer: decoder LSTM width; each encoder direction uses half of it.\n",
    "        num_layers: number of stacked encoder layers and of decoder layers.\n",
    "        embedded_size: width of the encoder and decoder embedding tables.\n",
    "        from_dict_size: number of rows of the encoder embedding table.\n",
    "        to_dict_size: number of rows of the decoder embedding table and number\n",
    "            of output logits.\n",
    "        learning_rate: learning rate handed to the Adam optimizer.\n",
    "        dropout: accepted for interface compatibility; currently unused.\n",
    "        beam_width: accepted for interface compatibility; currently unused.\n",
    "    \"\"\"\n",
    "    def __init__(self, size_layer, num_layers, embedded_size, \n",
    "                 from_dict_size, to_dict_size, learning_rate, \n",
    "                 dropout = 0.5, beam_width = 15):\n",
    "        \n",
    "        def lstm_cell(size, reuse=False):\n",
    "            return tf.nn.rnn_cell.LSTMCell(size, initializer=tf.orthogonal_initializer(),\n",
    "                                           reuse=reuse)\n",
    "        \n",
    "        self.X = tf.placeholder(tf.int32, [None, None])\n",
    "        self.Y = tf.placeholder(tf.int32, [None, None])\n",
    "        # Batches are padded with PAD, whose id is not 0. Counting non-zero ids\n",
    "        # would therefore treat padding as real tokens and report the padded\n",
    "        # width for every row, so count the positions that differ from PAD.\n",
    "        self.X_seq_len = tf.reduce_sum(tf.cast(tf.not_equal(self.X, PAD), tf.int32), 1)\n",
    "        self.Y_seq_len = tf.reduce_sum(tf.cast(tf.not_equal(self.Y, PAD), tf.int32), 1)\n",
    "        batch_size = tf.shape(self.X)[0]\n",
    "        # encoder: each layer consumes the concatenated outputs of the previous one\n",
    "        encoder_embeddings = tf.Variable(tf.random_uniform([from_dict_size, embedded_size], -1, 1))\n",
    "        encoder_embedded = tf.nn.embedding_lookup(encoder_embeddings, self.X)\n",
    "        for n in range(num_layers):\n",
    "            (out_fw, out_bw), (state_fw, state_bw) = tf.nn.bidirectional_dynamic_rnn(\n",
    "                cell_fw = lstm_cell(size_layer // 2),\n",
    "                cell_bw = lstm_cell(size_layer // 2),\n",
    "                inputs = encoder_embedded,\n",
    "                sequence_length = self.X_seq_len,\n",
    "                dtype = tf.float32,\n",
    "                scope = 'bidirectional_rnn_%d'%(n))\n",
    "            encoder_embedded = tf.concat((out_fw, out_bw), 2)\n",
    "        \n",
    "        # Final state of the last encoder layer (both directions concatenated),\n",
    "        # reused as the initial state of every decoder layer.\n",
    "        bi_state_c = tf.concat((state_fw.c, state_bw.c), -1)\n",
    "        bi_state_h = tf.concat((state_fw.h, state_bw.h), -1)\n",
    "        bi_lstm_state = tf.nn.rnn_cell.LSTMStateTuple(c=bi_state_c, h=bi_state_h)\n",
    "        self.encoder_state = tuple([bi_lstm_state] * num_layers)\n",
    "        \n",
    "        # Teacher forcing: decoder input is Y shifted right by one with GO in front.\n",
    "        main = tf.strided_slice(self.Y, [0, 0], [batch_size, -1], [1, 1])\n",
    "        decoder_input = tf.concat([tf.fill([batch_size, 1], GO), main], 1)\n",
    "        # decoder\n",
    "        decoder_embeddings = tf.Variable(tf.random_uniform([to_dict_size, embedded_size], -1, 1))\n",
    "        decoder_cells = tf.nn.rnn_cell.MultiRNNCell([lstm_cell(size_layer) for _ in range(num_layers)])\n",
    "        dense_layer = tf.layers.Dense(to_dict_size)\n",
    "        training_helper = tf.contrib.seq2seq.TrainingHelper(\n",
    "                inputs = tf.nn.embedding_lookup(decoder_embeddings, decoder_input),\n",
    "                sequence_length = self.Y_seq_len,\n",
    "                time_major = False)\n",
    "        training_decoder = tf.contrib.seq2seq.BasicDecoder(\n",
    "                cell = decoder_cells,\n",
    "                helper = training_helper,\n",
    "                initial_state = self.encoder_state,\n",
    "                output_layer = dense_layer)\n",
    "        training_decoder_output, _, _ = tf.contrib.seq2seq.dynamic_decode(\n",
    "                decoder = training_decoder,\n",
    "                impute_finished = True,\n",
    "                maximum_iterations = tf.reduce_max(self.Y_seq_len))\n",
    "        predicting_helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(\n",
    "                embedding = decoder_embeddings,\n",
    "                start_tokens = tf.tile(tf.constant([GO], dtype=tf.int32), [batch_size]),\n",
    "                end_token = EOS)\n",
    "        predicting_decoder = tf.contrib.seq2seq.BasicDecoder(\n",
    "                cell = decoder_cells,\n",
    "                helper = predicting_helper,\n",
    "                initial_state = self.encoder_state,\n",
    "                output_layer = dense_layer)\n",
    "        # Greedy decoding is capped at twice the longest input of the batch.\n",
    "        predicting_decoder_output, _, _ = tf.contrib.seq2seq.dynamic_decode(\n",
    "                decoder = predicting_decoder,\n",
    "                impute_finished = True,\n",
    "                maximum_iterations = 2 * tf.reduce_max(self.X_seq_len))\n",
    "        self.training_logits = training_decoder_output.rnn_output\n",
    "        self.predicting_ids = predicting_decoder_output.sample_id\n",
    "        # 1.0 on real target positions, 0.0 on padding.\n",
    "        masks = tf.sequence_mask(self.Y_seq_len, tf.reduce_max(self.Y_seq_len), dtype=tf.float32)\n",
    "        self.cost = tf.contrib.seq2seq.sequence_loss(logits = self.training_logits,\n",
    "                                                     targets = self.Y,\n",
    "                                                     weights = masks)\n",
    "        self.optimizer = tf.train.AdamOptimizer(learning_rate).minimize(self.cost)\n",
    "        \n",
    "        # Token-level accuracy over the unmasked positions only.\n",
    "        y_t = tf.argmax(self.training_logits,axis=2)\n",
    "        y_t = tf.cast(y_t, tf.int32)\n",
    "        self.prediction = tf.boolean_mask(y_t, masks)\n",
    "        mask_label = tf.boolean_mask(self.Y, masks)\n",
    "        correct_pred = tf.equal(self.prediction, mask_label)\n",
    "        self.accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model and training hyperparameters.\n",
    "size_layer, num_layers, embedded_size = 128, 2, 64\n",
    "learning_rate = 0.001\n",
    "batch_size, epoch = 32, 15"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Clear the default graph first, then build the model in a fresh session\n",
    "# and initialise its variables.\n",
    "tf.reset_default_graph()\n",
    "sess = tf.InteractiveSession()\n",
    "model = Stemmer(\n",
    "    size_layer = size_layer,\n",
    "    num_layers = num_layers,\n",
    "    embedded_size = embedded_size,\n",
    "    from_dict_size = len(dictionary_from),\n",
    "    to_dict_size = len(dictionary_to),\n",
    "    learning_rate = learning_rate)\n",
    "init_op = tf.global_variables_initializer()\n",
    "sess.run(init_op)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "def str_idx(corpus, dic, unk=None):\n",
    "    \"\"\"Convert token sequences into id sequences.\n",
    "\n",
    "    Args:\n",
    "        corpus: iterable of token sequences.\n",
    "        dic: token -> id mapping.\n",
    "        unk: id substituted for tokens missing from `dic`. When omitted, the\n",
    "            module-level UNK id is used, which is the previous behaviour.\n",
    "\n",
    "    Returns:\n",
    "        A list with one list of ids per input sequence.\n",
    "    \"\"\"\n",
    "    if unk is None:\n",
    "        unk = UNK\n",
    "    # Explicit membership test instead of a blanket `except Exception`: only a\n",
    "    # genuinely missing token falls back to `unk`, real errors still surface.\n",
    "    return [[dic[token] if token in dic else unk for token in sequence]\n",
    "            for sequence in corpus]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Encode each side with its own dictionary: inputs -> X, targets -> Y.\n",
    "X, Y = str_idx(before, dictionary_from), str_idx(after, dictionary_to)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hold out 20% of the pairs; the fixed seed makes the partition identical on\n",
    "# every Restart & Run All.\n",
    "train_X, test_X, train_Y, test_Y = train_test_split(X, Y, test_size = 0.2, random_state = 42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "def pad_sentence_batch(sentence_batch, pad_int):\n",
    "    \"\"\"Right-pad every sequence of a batch to the length of the longest one.\n",
    "\n",
    "    Args:\n",
    "        sentence_batch: sequence of id sequences (lists or tuples).\n",
    "        pad_int: id appended to the shorter sequences.\n",
    "\n",
    "    Returns:\n",
    "        padded_seqs: list of equally long lists (new objects; inputs untouched).\n",
    "        seq_lens: the original, unpadded length of each sequence.\n",
    "        An empty batch yields ([], []).\n",
    "    \"\"\"\n",
    "    seq_lens = [len(sentence) for sentence in sentence_batch]\n",
    "    # default=0 keeps an empty batch from raising ValueError inside max()\n",
    "    max_sentence_len = max(seq_lens, default=0)\n",
    "    padded_seqs = [list(sentence) + [pad_int] * (max_sentence_len - length)\n",
    "                   for sentence, length in zip(sentence_batch, seq_lens)]\n",
    "    return padded_seqs, seq_lens"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 250/250 [00:10<00:00, 24.82it/s, accuracy=0.709, cost=1.13] \n",
      "test minibatch loop: 100%|██████████| 63/63 [00:01<00:00, 61.55it/s, accuracy=0.734, cost=1.05] \n",
      "train minibatch loop:   1%|          | 3/250 [00:00<00:09, 25.67it/s, accuracy=0.728, cost=0.995]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 0, avg loss: 1.561948, avg accuracy: 0.606003\n",
      "epoch: 0, avg loss test: 1.034253, avg accuracy test: 0.734636\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 250/250 [00:10<00:00, 24.90it/s, accuracy=0.885, cost=0.45] \n",
      "test minibatch loop: 100%|██████████| 63/63 [00:00<00:00, 66.57it/s, accuracy=0.824, cost=0.712]\n",
      "train minibatch loop:   1%|          | 3/250 [00:00<00:09, 25.47it/s, accuracy=0.815, cost=0.687]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 1, avg loss: 0.787080, avg accuracy: 0.793143\n",
      "epoch: 1, avg loss test: 0.581813, avg accuracy test: 0.855047\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 250/250 [00:09<00:00, 25.01it/s, accuracy=0.897, cost=0.438]\n",
      "test minibatch loop: 100%|██████████| 63/63 [00:00<00:00, 66.57it/s, accuracy=0.924, cost=0.304]\n",
      "train minibatch loop:   1%|          | 3/250 [00:00<00:09, 24.85it/s, accuracy=0.922, cost=0.32] "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 2, avg loss: 0.456560, avg accuracy: 0.881484\n",
      "epoch: 2, avg loss test: 0.343726, avg accuracy test: 0.920939\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 250/250 [00:10<00:00, 24.95it/s, accuracy=0.932, cost=0.281]\n",
      "test minibatch loop: 100%|██████████| 63/63 [00:00<00:00, 66.49it/s, accuracy=0.908, cost=0.297]\n",
      "train minibatch loop:   1%|          | 3/250 [00:00<00:10, 24.01it/s, accuracy=0.956, cost=0.216]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 3, avg loss: 0.289203, avg accuracy: 0.928574\n",
      "epoch: 3, avg loss test: 0.235505, avg accuracy test: 0.947662\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 250/250 [00:10<00:00, 24.83it/s, accuracy=0.963, cost=0.153]\n",
      "test minibatch loop: 100%|██████████| 63/63 [00:00<00:00, 67.04it/s, accuracy=0.966, cost=0.144]\n",
      "train minibatch loop:   1%|          | 3/250 [00:00<00:10, 24.59it/s, accuracy=0.946, cost=0.184]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 4, avg loss: 0.196689, avg accuracy: 0.952377\n",
      "epoch: 4, avg loss test: 0.192981, avg accuracy test: 0.959513\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 250/250 [00:10<00:00, 24.98it/s, accuracy=0.978, cost=0.117] \n",
      "test minibatch loop: 100%|██████████| 63/63 [00:00<00:00, 66.63it/s, accuracy=0.938, cost=0.226]\n",
      "train minibatch loop:   1%|          | 3/250 [00:00<00:09, 25.22it/s, accuracy=0.975, cost=0.113]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 5, avg loss: 0.147957, avg accuracy: 0.964733\n",
      "epoch: 5, avg loss test: 0.161298, avg accuracy test: 0.966909\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 250/250 [00:09<00:00, 25.00it/s, accuracy=0.969, cost=0.115] \n",
      "test minibatch loop: 100%|██████████| 63/63 [00:00<00:00, 66.81it/s, accuracy=0.989, cost=0.0534]\n",
      "train minibatch loop:   1%|          | 3/250 [00:00<00:08, 27.76it/s, accuracy=0.984, cost=0.0803]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 6, avg loss: 0.117273, avg accuracy: 0.972617\n",
      "epoch: 6, avg loss test: 0.112091, avg accuracy test: 0.981371\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 250/250 [00:10<00:00, 24.62it/s, accuracy=0.986, cost=0.0617]\n",
      "test minibatch loop: 100%|██████████| 63/63 [00:00<00:00, 66.04it/s, accuracy=0.991, cost=0.0451]\n",
      "train minibatch loop:   1%|          | 3/250 [00:00<00:09, 26.14it/s, accuracy=0.988, cost=0.0558]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 7, avg loss: 0.093180, avg accuracy: 0.978554\n",
      "epoch: 7, avg loss test: 0.107160, avg accuracy test: 0.979795\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 250/250 [00:10<00:00, 24.92it/s, accuracy=0.983, cost=0.0552]\n",
      "test minibatch loop: 100%|██████████| 63/63 [00:00<00:00, 66.06it/s, accuracy=0.969, cost=0.131] \n",
      "train minibatch loop:   1%|          | 3/250 [00:00<00:09, 26.41it/s, accuracy=0.983, cost=0.0612]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 8, avg loss: 0.077615, avg accuracy: 0.982014\n",
      "epoch: 8, avg loss test: 0.093358, avg accuracy test: 0.984920\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 250/250 [00:10<00:00, 24.84it/s, accuracy=0.989, cost=0.067] \n",
      "test minibatch loop: 100%|██████████| 63/63 [00:00<00:00, 66.91it/s, accuracy=0.99, cost=0.0521] \n",
      "train minibatch loop:   1%|          | 3/250 [00:00<00:09, 25.85it/s, accuracy=0.99, cost=0.0388] "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 9, avg loss: 0.065573, avg accuracy: 0.984960\n",
      "epoch: 9, avg loss test: 0.091062, avg accuracy test: 0.985003\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 250/250 [00:10<00:00, 24.99it/s, accuracy=0.978, cost=0.0952]\n",
      "test minibatch loop: 100%|██████████| 63/63 [00:00<00:00, 65.95it/s, accuracy=0.979, cost=0.0822]\n",
      "train minibatch loop:   1%|          | 3/250 [00:00<00:09, 24.97it/s, accuracy=0.988, cost=0.0644]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 10, avg loss: 0.056420, avg accuracy: 0.987175\n",
      "epoch: 10, avg loss test: 0.078284, avg accuracy test: 0.988695\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 250/250 [00:10<00:00, 24.81it/s, accuracy=0.982, cost=0.0684]\n",
      "test minibatch loop: 100%|██████████| 63/63 [00:00<00:00, 66.22it/s, accuracy=0.952, cost=0.147] \n",
      "train minibatch loop:   1%|          | 3/250 [00:00<00:09, 25.98it/s, accuracy=0.991, cost=0.0354]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 11, avg loss: 0.048207, avg accuracy: 0.988751\n",
      "epoch: 11, avg loss test: 0.069898, avg accuracy test: 0.990280\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 250/250 [00:09<00:00, 25.22it/s, accuracy=0.99, cost=0.0473] \n",
      "test minibatch loop: 100%|██████████| 63/63 [00:00<00:00, 65.70it/s, accuracy=0.988, cost=0.0439]\n",
      "train minibatch loop:   1%|          | 3/250 [00:00<00:09, 25.96it/s, accuracy=0.99, cost=0.045]  "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 12, avg loss: 0.041016, avg accuracy: 0.991018\n",
      "epoch: 12, avg loss test: 0.065119, avg accuracy test: 0.992309\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 250/250 [00:09<00:00, 25.09it/s, accuracy=0.989, cost=0.0527]\n",
      "test minibatch loop: 100%|██████████| 63/63 [00:00<00:00, 66.77it/s, accuracy=0.99, cost=0.0346] \n",
      "train minibatch loop:   1%|          | 3/250 [00:00<00:09, 27.04it/s, accuracy=0.992, cost=0.0227]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 13, avg loss: 0.034818, avg accuracy: 0.992361\n",
      "epoch: 13, avg loss test: 0.065551, avg accuracy test: 0.991996\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "train minibatch loop: 100%|██████████| 250/250 [00:09<00:00, 25.10it/s, accuracy=0.998, cost=0.0183]\n",
      "test minibatch loop: 100%|██████████| 63/63 [00:00<00:00, 65.63it/s, accuracy=0.988, cost=0.0385]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 14, avg loss: 0.029981, avg accuracy: 0.993711\n",
      "epoch: 14, avg loss test: 0.067402, avg accuracy test: 0.990294\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "from tqdm import tqdm\n",
    "\n",
    "# One pass over the training split (with parameter updates) followed by one\n",
    "# pass over the test split (metrics only) per epoch. `batch_x` / `batch_y`\n",
    "# are deliberately left at module level: after the loop they hold the last\n",
    "# test mini-batch.\n",
    "for EPOCH in range(epoch):\n",
    "    total_loss, total_accuracy, total_loss_test, total_accuracy_test = 0, 0, 0, 0\n",
    "    train_X, train_Y = shuffle(train_X, train_Y)\n",
    "    test_X, test_Y = shuffle(test_X, test_Y)\n",
    "    train_starts = range(0, len(train_X), batch_size)\n",
    "    test_starts = range(0, len(test_X), batch_size)\n",
    "\n",
    "    pbar = tqdm(train_starts, desc='train minibatch loop')\n",
    "    for k in pbar:\n",
    "        batch_x, _ = pad_sentence_batch(train_X[k: k+batch_size], PAD)\n",
    "        batch_y, _ = pad_sentence_batch(train_Y[k: k+batch_size], PAD)\n",
    "        acc, loss, _ = sess.run([model.accuracy, model.cost, model.optimizer], \n",
    "                                      feed_dict={model.X:batch_x,\n",
    "                                                model.Y:batch_y})\n",
    "        total_loss += loss\n",
    "        total_accuracy += acc\n",
    "        pbar.set_postfix(cost=loss, accuracy = acc)\n",
    "        \n",
    "    pbar = tqdm(test_starts, desc='test minibatch loop')\n",
    "    for k in pbar:\n",
    "        batch_x, _ = pad_sentence_batch(test_X[k: k+batch_size], PAD)\n",
    "        batch_y, _ = pad_sentence_batch(test_Y[k: k+batch_size], PAD)\n",
    "        acc, loss = sess.run([model.accuracy, model.cost], \n",
    "                                      feed_dict={model.X:batch_x,\n",
    "                                                model.Y:batch_y})\n",
    "        total_loss_test += loss\n",
    "        total_accuracy_test += acc\n",
    "        pbar.set_postfix(cost=loss, accuracy = acc)\n",
    "        \n",
    "    # Average over the number of mini-batches actually run. len(data) / batch_size\n",
    "    # is fractional when the last batch is short and would inflate the mean.\n",
    "    total_loss /= len(train_starts)\n",
    "    total_accuracy /= len(train_starts)\n",
    "    total_loss_test /= len(test_starts)\n",
    "    total_accuracy_test /= len(test_starts)\n",
    "        \n",
    "    print('epoch: %d, avg loss: %f, avg accuracy: %f'%(EPOCH, total_loss, total_accuracy))\n",
    "    print('epoch: %d, avg loss test: %f, avg accuracy test: %f'%(EPOCH, total_loss_test, total_accuracy_test))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `batch_x` is whatever the training cell left behind, i.e. the final padded\n",
    "# test mini-batch; decode it greedily.\n",
    "greedy_feed = {model.X: batch_x}\n",
    "predicted = sess.run(model.predicting_ids, feed_dict=greedy_feed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "row 1\n",
      "BEFORE: <chimerae>\n",
      "REAL AFTER: <chimera>\n",
      "PREDICTED AFTER: <chimerae> \n",
      "\n",
      "row 2\n",
      "BEFORE: <catheterisations>\n",
      "REAL AFTER: <catheterisation>\n",
      "PREDICTED AFTER: <catheterisation> \n",
      "\n",
      "row 3\n",
      "BEFORE: <consuming>\n",
      "REAL AFTER: <consume>\n",
      "PREDICTED AFTER: <consume> \n",
      "\n",
      "row 4\n",
      "BEFORE: <damped>\n",
      "REAL AFTER: <damp>\n",
      "PREDICTED AFTER: <damp> \n",
      "\n",
      "row 5\n",
      "BEFORE: <caraways>\n",
      "REAL AFTER: <caraway>\n",
      "PREDICTED AFTER: <caraway> \n",
      "\n",
      "row 6\n",
      "BEFORE: <aunts>\n",
      "REAL AFTER: <aunt>\n",
      "PREDICTED AFTER: <aunt> \n",
      "\n",
      "row 7\n",
      "BEFORE: <admirers>\n",
      "REAL AFTER: <admirer>\n",
      "PREDICTED AFTER: <admirer> \n",
      "\n",
      "row 8\n",
      "BEFORE: <contracts>\n",
      "REAL AFTER: <contract>\n",
      "PREDICTED AFTER: <contract> \n",
      "\n",
      "row 9\n",
      "BEFORE: <atoms>\n",
      "REAL AFTER: <atom>\n",
      "PREDICTED AFTER: <atom> \n",
      "\n",
      "row 10\n",
      "BEFORE: <chillier>\n",
      "REAL AFTER: <chilly>\n",
      "PREDICTED AFTER: <chilly> \n",
      "\n",
      "row 11\n",
      "BEFORE: <ameliorations>\n",
      "REAL AFTER: <amelioration>\n",
      "PREDICTED AFTER: <amelioration> \n",
      "\n",
      "row 12\n",
      "BEFORE: <contaminates>\n",
      "REAL AFTER: <contaminate>\n",
      "PREDICTED AFTER: <contaminaty> \n",
      "\n",
      "row 13\n",
      "BEFORE: <anti-heroes>\n",
      "REAL AFTER: <anti-hero>\n",
      "PREDICTED AFTER: <anti-hery> \n",
      "\n",
      "row 14\n",
      "BEFORE: <destines>\n",
      "REAL AFTER: <destine>\n",
      "PREDICTED AFTER: <destine> \n",
      "\n",
      "row 15\n",
      "BEFORE: <co-authored>\n",
      "REAL AFTER: <co-author>\n",
      "PREDICTED AFTER: <co-author> \n",
      "\n",
      "row 16\n",
      "BEFORE: <clobbers>\n",
      "REAL AFTER: <clobber>\n",
      "PREDICTED AFTER: <clobber> \n",
      "\n"
     ]
    }
   ],
   "source": [
    "RESERVED_IDS = (0, 1, 2, 3)  # ids skipped when rendering a sequence\n",
    "\n",
    "def _render(ids, rev_dictionary):\n",
    "    \"\"\"Concatenate the tokens behind `ids`, leaving out the reserved ids.\"\"\"\n",
    "    return ''.join([rev_dictionary[n] for n in ids if n not in RESERVED_IDS])\n",
    "\n",
    "for row, (source_ids, target_ids, predicted_ids) in enumerate(zip(batch_x, batch_y, predicted), start=1):\n",
    "    print('row %d'%(row))\n",
    "    print('BEFORE:', _render(source_ids, rev_dictionary_from))\n",
    "    print('REAL AFTER:', _render(target_ids, rev_dictionary_to))\n",
    "    print('PREDICTED AFTER:', _render(predicted_ids, rev_dictionary_to),'\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
